{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Easy Ensemble Learning with h2oEnsemble\n",
    "\n",
    "## Introduction\n",
    "\n",
    "In statistics and machine learning, ensemble methods use multiple models to **obtain better predictive performance** than could be obtained from any of the constituent models (Wikipredia, 2015). This notebook demonstrates an easy way to carry out ensemble learning with H2O models using **`h2oEnsemble`**.\n",
    "\n",
    "### Key Benefit\n",
    "\n",
    "We give our users the ability to build, compare and stack different H2O, MXNet, TensorFlow and Caffe models quickly and easily using the H2O platform.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "We need three R packages for this demo: **`h2o`**, **`h2oEnsemble`** and **`mlbench`**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load R Packages\n",
    "suppressPackageStartupMessages(library(h2o))\n",
    "suppressPackageStartupMessages(library(mlbench))     # for Boston Housing Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Install h2oEnsemble from GitHub if needed\n",
    "# Reference: https://github.com/h2oai/h2o-3/tree/master/h2o-r/ensemble\n",
    "if (!require(h2oEnsemble)) {\n",
    "    install.packages(\"https://h2o-release.s3.amazonaws.com/h2o-ensemble/R/h2oEnsemble_0.1.8.tar.gz\", repos = NULL)\n",
    "}\n",
    "suppressPackageStartupMessages(library(h2oEnsemble)) # for model stacking"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Boston Housing Data\n",
    "\n",
    "The dataset used in this demo is **`Boston Housing`** from **`mlbench`**, it contains housing values in suburbs of Boston.\n",
    "\n",
    "\n",
    "- **Reference**: UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/Housing)\n",
    "- **Source**: This dataset was taken from the StatLib library which is maintained at Carnegie Mellon University. \n",
    "- **Creator**: Harrison, D. and Rubinfeld, D.L., 'Hedonic prices and the demand for clean air', J. Environ. Economics & Management, vol.5, 81-102, 1978.\n",
    "- **Type**: Regression\n",
    "- **Dimensions**: 506 instances, 13 numeric features and 1 numeric target.\n",
    "\n",
    "\n",
    "- **13 Features**:\n",
    "    - **CRIM**: per capita crime rate by town \n",
    "    - **ZN**: proportion of residential land zoned for lots over 25,000 sq.ft. \n",
    "    - **INDUS**: proportion of non-retail business acres per town \n",
    "    - **CHAS**: Charles River dummy variable (= 1 if tract bounds river; 0 otherwise) \n",
    "    - **NOX**: nitric oxides concentration (parts per 10 million) \n",
    "    - **RM**: average number of rooms per dwelling \n",
    "    - **AGE**: proportion of owner-occupied units built prior to 1940 \n",
    "    - **DIS**: weighted distances to five Boston employment centres \n",
    "    - **RAD**: index of accessibility to radial highways \n",
    "    - **TAX**: full-value property-tax rate per $10,000 \n",
    "    - **PTRATIO**: pupil-teacher ratio by town \n",
    "    - **B**: 1000(Bk - 0.63)^2 where Bk is the proportion of blacks by town \n",
    "    - **LSTAT**: % lower status of the population \n",
    "\n",
    "\n",
    "- **Target**:\n",
    "    - **MEDV**: Median value of owner-occupied homes in $1000's (this is the value we want to predict)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<thead><tr><th scope=col>crim</th><th scope=col>zn</th><th scope=col>indus</th><th scope=col>chas</th><th scope=col>nox</th><th scope=col>rm</th><th scope=col>age</th><th scope=col>dis</th><th scope=col>rad</th><th scope=col>tax</th><th scope=col>ptratio</th><th scope=col>b</th><th scope=col>lstat</th><th scope=col>medv</th></tr></thead>\n",
       "<tbody>\n",
       "\t<tr><td>0.00632</td><td>18     </td><td>2.31   </td><td>0      </td><td>0.538  </td><td>6.575  </td><td>65.2   </td><td>4.0900 </td><td>1      </td><td>296    </td><td>15.3   </td><td>396.90 </td><td>4.98   </td><td>24.0   </td></tr>\n",
       "\t<tr><td>0.02731</td><td> 0     </td><td>7.07   </td><td>0      </td><td>0.469  </td><td>6.421  </td><td>78.9   </td><td>4.9671 </td><td>2      </td><td>242    </td><td>17.8   </td><td>396.90 </td><td>9.14   </td><td>21.6   </td></tr>\n",
       "\t<tr><td>0.02729</td><td> 0     </td><td>7.07   </td><td>0      </td><td>0.469  </td><td>7.185  </td><td>61.1   </td><td>4.9671 </td><td>2      </td><td>242    </td><td>17.8   </td><td>392.83 </td><td>4.03   </td><td>34.7   </td></tr>\n",
       "\t<tr><td>0.03237</td><td> 0     </td><td>2.18   </td><td>0      </td><td>0.458  </td><td>6.998  </td><td>45.8   </td><td>6.0622 </td><td>3      </td><td>222    </td><td>18.7   </td><td>394.63 </td><td>2.94   </td><td>33.4   </td></tr>\n",
       "\t<tr><td>0.06905</td><td> 0     </td><td>2.18   </td><td>0      </td><td>0.458  </td><td>7.147  </td><td>54.2   </td><td>6.0622 </td><td>3      </td><td>222    </td><td>18.7   </td><td>396.90 </td><td>5.33   </td><td>36.2   </td></tr>\n",
       "\t<tr><td>0.02985</td><td> 0     </td><td>2.18   </td><td>0      </td><td>0.458  </td><td>6.430  </td><td>58.7   </td><td>6.0622 </td><td>3      </td><td>222    </td><td>18.7   </td><td>394.12 </td><td>5.21   </td><td>28.7   </td></tr>\n",
       "</tbody>\n",
       "</table>\n"
      ],
      "text/latex": [
       "\\begin{tabular}{r|llllllllllllll}\n",
       " crim & zn & indus & chas & nox & rm & age & dis & rad & tax & ptratio & b & lstat & medv\\\\\n",
       "\\hline\n",
       "\t 0.00632 & 18      & 2.31    & 0       & 0.538   & 6.575   & 65.2    & 4.0900  & 1       & 296     & 15.3    & 396.90  & 4.98    & 24.0   \\\\\n",
       "\t 0.02731 &  0      & 7.07    & 0       & 0.469   & 6.421   & 78.9    & 4.9671  & 2       & 242     & 17.8    & 396.90  & 9.14    & 21.6   \\\\\n",
       "\t 0.02729 &  0      & 7.07    & 0       & 0.469   & 7.185   & 61.1    & 4.9671  & 2       & 242     & 17.8    & 392.83  & 4.03    & 34.7   \\\\\n",
       "\t 0.03237 &  0      & 2.18    & 0       & 0.458   & 6.998   & 45.8    & 6.0622  & 3       & 222     & 18.7    & 394.63  & 2.94    & 33.4   \\\\\n",
       "\t 0.06905 &  0      & 2.18    & 0       & 0.458   & 7.147   & 54.2    & 6.0622  & 3       & 222     & 18.7    & 396.90  & 5.33    & 36.2   \\\\\n",
       "\t 0.02985 &  0      & 2.18    & 0       & 0.458   & 6.430   & 58.7    & 6.0622  & 3       & 222     & 18.7    & 394.12  & 5.21    & 28.7   \\\\\n",
       "\\end{tabular}\n"
      ],
      "text/plain": [
       "  crim    zn indus chas nox   rm    age  dis    rad tax ptratio b      lstat\n",
       "1 0.00632 18 2.31  0    0.538 6.575 65.2 4.0900 1   296 15.3    396.90 4.98 \n",
       "2 0.02731  0 7.07  0    0.469 6.421 78.9 4.9671 2   242 17.8    396.90 9.14 \n",
       "3 0.02729  0 7.07  0    0.469 7.185 61.1 4.9671 2   242 17.8    392.83 4.03 \n",
       "4 0.03237  0 2.18  0    0.458 6.998 45.8 6.0622 3   222 18.7    394.63 2.94 \n",
       "5 0.06905  0 2.18  0    0.458 7.147 54.2 6.0622 3   222 18.7    396.90 5.33 \n",
       "6 0.02985  0 2.18  0    0.458 6.430 58.7 6.0622 3   222 18.7    394.12 5.21 \n",
       "  medv\n",
       "1 24.0\n",
       "2 21.6\n",
       "3 34.7\n",
       "4 33.4\n",
       "5 36.2\n",
       "6 28.7"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<ol class=list-inline>\n",
       "\t<li>506</li>\n",
       "\t<li>14</li>\n",
       "</ol>\n"
      ],
      "text/latex": [
       "\\begin{enumerate*}\n",
       "\\item 506\n",
       "\\item 14\n",
       "\\end{enumerate*}\n"
      ],
      "text/markdown": [
       "1. 506\n",
       "2. 14\n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "[1] 506  14"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Import data\n",
    "data(BostonHousing)\n",
    "head(BostonHousing)\n",
    "dim(BostonHousing)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Splitting Data into Training/Test Set\n",
    "\n",
    "We want to evaluate the predictive performance on a holdout dataset. The following code split the `Boston Housing` data randomly into:\n",
    "\n",
    "- Training: 400 instances\n",
    "- Test: 106 instances\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Split data\n",
    "set.seed(1234)\n",
    "row_train <- sample(1:nrow(BostonHousing), 400)\n",
    "train <- BostonHousing[row_train,]\n",
    "test <- BostonHousing[-row_train,]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<ol class=list-inline>\n",
       "\t<li>400</li>\n",
       "\t<li>14</li>\n",
       "</ol>\n"
      ],
      "text/latex": [
       "\\begin{enumerate*}\n",
       "\\item 400\n",
       "\\item 14\n",
       "\\end{enumerate*}\n"
      ],
      "text/markdown": [
       "1. 400\n",
       "2. 14\n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "[1] 400  14"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<thead><tr><th></th><th scope=col>crim</th><th scope=col>zn</th><th scope=col>indus</th><th scope=col>chas</th><th scope=col>nox</th><th scope=col>rm</th><th scope=col>age</th><th scope=col>dis</th><th scope=col>rad</th><th scope=col>tax</th><th scope=col>ptratio</th><th scope=col>b</th><th scope=col>lstat</th><th scope=col>medv</th></tr></thead>\n",
       "<tbody>\n",
       "\t<tr><th scope=row>58</th><td>0.01432</td><td>100    </td><td> 1.32  </td><td>0      </td><td>0.411  </td><td>6.816  </td><td>40.5   </td><td>8.3248 </td><td> 5     </td><td>256    </td><td>15.1   </td><td>392.90 </td><td> 3.95  </td><td>31.6   </td></tr>\n",
       "\t<tr><th scope=row>315</th><td>0.36920</td><td>  0    </td><td> 9.90  </td><td>0      </td><td>0.544  </td><td>6.567  </td><td>87.3   </td><td>3.6023 </td><td> 4     </td><td>304    </td><td>18.4   </td><td>395.69 </td><td> 9.28  </td><td>23.8   </td></tr>\n",
       "\t<tr><th scope=row>308</th><td>0.04932</td><td> 33    </td><td> 2.18  </td><td>0      </td><td>0.472  </td><td>6.849  </td><td>70.3   </td><td>3.1827 </td><td> 7     </td><td>222    </td><td>18.4   </td><td>396.90 </td><td> 7.53  </td><td>28.2   </td></tr>\n",
       "\t<tr><th scope=row>314</th><td>0.26938</td><td>  0    </td><td> 9.90  </td><td>0      </td><td>0.544  </td><td>6.266  </td><td>82.8   </td><td>3.2628 </td><td> 4     </td><td>304    </td><td>18.4   </td><td>393.39 </td><td> 7.90  </td><td>21.6   </td></tr>\n",
       "\t<tr><th scope=row>433</th><td>6.44405</td><td>  0    </td><td>18.10  </td><td>0      </td><td>0.584  </td><td>6.425  </td><td>74.8   </td><td>2.2004 </td><td>24     </td><td>666    </td><td>20.2   </td><td> 97.95 </td><td>12.03  </td><td>16.1   </td></tr>\n",
       "\t<tr><th scope=row>321</th><td>0.16760</td><td>  0    </td><td> 7.38  </td><td>0      </td><td>0.493  </td><td>6.426  </td><td>52.3   </td><td>4.5404 </td><td> 5     </td><td>287    </td><td>19.6   </td><td>396.90 </td><td> 7.20  </td><td>23.8   </td></tr>\n",
       "</tbody>\n",
       "</table>\n"
      ],
      "text/latex": [
       "\\begin{tabular}{r|llllllllllllll}\n",
       "  & crim & zn & indus & chas & nox & rm & age & dis & rad & tax & ptratio & b & lstat & medv\\\\\n",
       "\\hline\n",
       "\t58 & 0.01432 & 100     &  1.32   & 0       & 0.411   & 6.816   & 40.5    & 8.3248  &  5      & 256     & 15.1    & 392.90  &  3.95   & 31.6   \\\\\n",
       "\t315 & 0.36920 &   0     &  9.90   & 0       & 0.544   & 6.567   & 87.3    & 3.6023  &  4      & 304     & 18.4    & 395.69  &  9.28   & 23.8   \\\\\n",
       "\t308 & 0.04932 &  33     &  2.18   & 0       & 0.472   & 6.849   & 70.3    & 3.1827  &  7      & 222     & 18.4    & 396.90  &  7.53   & 28.2   \\\\\n",
       "\t314 & 0.26938 &   0     &  9.90   & 0       & 0.544   & 6.266   & 82.8    & 3.2628  &  4      & 304     & 18.4    & 393.39  &  7.90   & 21.6   \\\\\n",
       "\t433 & 6.44405 &   0     & 18.10   & 0       & 0.584   & 6.425   & 74.8    & 2.2004  & 24      & 666     & 20.2    &  97.95  & 12.03   & 16.1   \\\\\n",
       "\t321 & 0.16760 &   0     &  7.38   & 0       & 0.493   & 6.426   & 52.3    & 4.5404  &  5      & 287     & 19.6    & 396.90  &  7.20   & 23.8   \\\\\n",
       "\\end{tabular}\n"
      ],
      "text/plain": [
       "    crim    zn  indus chas nox   rm    age  dis    rad tax ptratio b      lstat\n",
       "58  0.01432 100  1.32 0    0.411 6.816 40.5 8.3248  5  256 15.1    392.90  3.95\n",
       "315 0.36920   0  9.90 0    0.544 6.567 87.3 3.6023  4  304 18.4    395.69  9.28\n",
       "308 0.04932  33  2.18 0    0.472 6.849 70.3 3.1827  7  222 18.4    396.90  7.53\n",
       "314 0.26938   0  9.90 0    0.544 6.266 82.8 3.2628  4  304 18.4    393.39  7.90\n",
       "433 6.44405   0 18.10 0    0.584 6.425 74.8 2.2004 24  666 20.2     97.95 12.03\n",
       "321 0.16760   0  7.38 0    0.493 6.426 52.3 4.5404  5  287 19.6    396.90  7.20\n",
       "    medv\n",
       "58  31.6\n",
       "315 23.8\n",
       "308 28.2\n",
       "314 21.6\n",
       "433 16.1\n",
       "321 23.8"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "      crim                zn             indus       chas         nox        \n",
       " Min.   : 0.00632   Min.   :  0.00   Min.   : 0.46   0:370   Min.   :0.3850  \n",
       " 1st Qu.: 0.07782   1st Qu.:  0.00   1st Qu.: 5.13   1: 30   1st Qu.:0.4520  \n",
       " Median : 0.24751   Median :  0.00   Median : 8.56           Median :0.5380  \n",
       " Mean   : 3.33351   Mean   : 12.01   Mean   :10.98           Mean   :0.5549  \n",
       " 3rd Qu.: 3.48946   3rd Qu.: 18.50   3rd Qu.:18.10           3rd Qu.:0.6258  \n",
       " Max.   :73.53410   Max.   :100.00   Max.   :27.74           Max.   :0.8710  \n",
       "       rm             age              dis              rad       \n",
       " Min.   :3.561   Min.   :  6.20   Min.   : 1.130   Min.   : 1.00  \n",
       " 1st Qu.:5.883   1st Qu.: 47.08   1st Qu.: 2.103   1st Qu.: 4.00  \n",
       " Median :6.205   Median : 77.75   Median : 3.239   Median : 5.00  \n",
       " Mean   :6.273   Mean   : 69.25   Mean   : 3.824   Mean   : 9.44  \n",
       " 3rd Qu.:6.626   3rd Qu.: 94.03   3rd Qu.: 5.234   3rd Qu.:24.00  \n",
       " Max.   :8.780   Max.   :100.00   Max.   :12.127   Max.   :24.00  \n",
       "      tax           ptratio            b              lstat      \n",
       " Min.   :187.0   Min.   :12.60   Min.   :  2.52   Min.   : 1.73  \n",
       " 1st Qu.:279.0   1st Qu.:17.40   1st Qu.:376.46   1st Qu.: 7.17  \n",
       " Median :330.0   Median :19.10   Median :391.99   Median :11.25  \n",
       " Mean   :404.8   Mean   :18.52   Mean   :359.94   Mean   :12.61  \n",
       " 3rd Qu.:666.0   3rd Qu.:20.20   3rd Qu.:396.54   3rd Qu.:16.43  \n",
       " Max.   :711.0   Max.   :22.00   Max.   :396.90   Max.   :37.97  \n",
       "      medv      \n",
       " Min.   : 5.00  \n",
       " 1st Qu.:17.27  \n",
       " Median :21.15  \n",
       " Mean   :22.51  \n",
       " 3rd Qu.:24.85  \n",
       " Max.   :50.00  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Training data - quick summary\n",
    "dim(train)\n",
    "head(train)\n",
    "summary(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<ol class=list-inline>\n",
       "\t<li>106</li>\n",
       "\t<li>14</li>\n",
       "</ol>\n"
      ],
      "text/latex": [
       "\\begin{enumerate*}\n",
       "\\item 106\n",
       "\\item 14\n",
       "\\end{enumerate*}\n"
      ],
      "text/markdown": [
       "1. 106\n",
       "2. 14\n",
       "\n",
       "\n"
      ],
      "text/plain": [
       "[1] 106  14"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<table>\n",
       "<thead><tr><th></th><th scope=col>crim</th><th scope=col>zn</th><th scope=col>indus</th><th scope=col>chas</th><th scope=col>nox</th><th scope=col>rm</th><th scope=col>age</th><th scope=col>dis</th><th scope=col>rad</th><th scope=col>tax</th><th scope=col>ptratio</th><th scope=col>b</th><th scope=col>lstat</th><th scope=col>medv</th></tr></thead>\n",
       "<tbody>\n",
       "\t<tr><th scope=row>2</th><td>0.02731</td><td> 0.0   </td><td>7.07   </td><td>0      </td><td>0.469  </td><td>6.421  </td><td> 78.9  </td><td>4.9671 </td><td>2      </td><td>242    </td><td>17.8   </td><td>396.90 </td><td> 9.14  </td><td>21.6   </td></tr>\n",
       "\t<tr><th scope=row>10</th><td>0.17004</td><td>12.5   </td><td>7.87   </td><td>0      </td><td>0.524  </td><td>6.004  </td><td> 85.9  </td><td>6.5921 </td><td>5      </td><td>311    </td><td>15.2   </td><td>386.71 </td><td>17.10  </td><td>18.9   </td></tr>\n",
       "\t<tr><th scope=row>13</th><td>0.09378</td><td>12.5   </td><td>7.87   </td><td>0      </td><td>0.524  </td><td>5.889  </td><td> 39.0  </td><td>5.4509 </td><td>5      </td><td>311    </td><td>15.2   </td><td>390.50 </td><td>15.71  </td><td>21.7   </td></tr>\n",
       "\t<tr><th scope=row>18</th><td>0.78420</td><td> 0.0   </td><td>8.14   </td><td>0      </td><td>0.538  </td><td>5.990  </td><td> 81.7  </td><td>4.2579 </td><td>4      </td><td>307    </td><td>21.0   </td><td>386.75 </td><td>14.67  </td><td>17.5   </td></tr>\n",
       "\t<tr><th scope=row>24</th><td>0.98843</td><td> 0.0   </td><td>8.14   </td><td>0      </td><td>0.538  </td><td>5.813  </td><td>100.0  </td><td>4.0952 </td><td>4      </td><td>307    </td><td>21.0   </td><td>394.54 </td><td>19.88  </td><td>14.5   </td></tr>\n",
       "\t<tr><th scope=row>28</th><td>0.95577</td><td> 0.0   </td><td>8.14   </td><td>0      </td><td>0.538  </td><td>6.047  </td><td> 88.8  </td><td>4.4534 </td><td>4      </td><td>307    </td><td>21.0   </td><td>306.38 </td><td>17.28  </td><td>14.8   </td></tr>\n",
       "</tbody>\n",
       "</table>\n"
      ],
      "text/latex": [
       "\\begin{tabular}{r|llllllllllllll}\n",
       "  & crim & zn & indus & chas & nox & rm & age & dis & rad & tax & ptratio & b & lstat & medv\\\\\n",
       "\\hline\n",
       "\t2 & 0.02731 &  0.0    & 7.07    & 0       & 0.469   & 6.421   &  78.9   & 4.9671  & 2       & 242     & 17.8    & 396.90  &  9.14   & 21.6   \\\\\n",
       "\t10 & 0.17004 & 12.5    & 7.87    & 0       & 0.524   & 6.004   &  85.9   & 6.5921  & 5       & 311     & 15.2    & 386.71  & 17.10   & 18.9   \\\\\n",
       "\t13 & 0.09378 & 12.5    & 7.87    & 0       & 0.524   & 5.889   &  39.0   & 5.4509  & 5       & 311     & 15.2    & 390.50  & 15.71   & 21.7   \\\\\n",
       "\t18 & 0.78420 &  0.0    & 8.14    & 0       & 0.538   & 5.990   &  81.7   & 4.2579  & 4       & 307     & 21.0    & 386.75  & 14.67   & 17.5   \\\\\n",
       "\t24 & 0.98843 &  0.0    & 8.14    & 0       & 0.538   & 5.813   & 100.0   & 4.0952  & 4       & 307     & 21.0    & 394.54  & 19.88   & 14.5   \\\\\n",
       "\t28 & 0.95577 &  0.0    & 8.14    & 0       & 0.538   & 6.047   &  88.8   & 4.4534  & 4       & 307     & 21.0    & 306.38  & 17.28   & 14.8   \\\\\n",
       "\\end{tabular}\n"
      ],
      "text/plain": [
       "   crim    zn   indus chas nox   rm    age   dis    rad tax ptratio b     \n",
       "2  0.02731  0.0 7.07  0    0.469 6.421  78.9 4.9671 2   242 17.8    396.90\n",
       "10 0.17004 12.5 7.87  0    0.524 6.004  85.9 6.5921 5   311 15.2    386.71\n",
       "13 0.09378 12.5 7.87  0    0.524 5.889  39.0 5.4509 5   311 15.2    390.50\n",
       "18 0.78420  0.0 8.14  0    0.538 5.990  81.7 4.2579 4   307 21.0    386.75\n",
       "24 0.98843  0.0 8.14  0    0.538 5.813 100.0 4.0952 4   307 21.0    394.54\n",
       "28 0.95577  0.0 8.14  0    0.538 6.047  88.8 4.4534 4   307 21.0    306.38\n",
       "   lstat medv\n",
       "2   9.14 21.6\n",
       "10 17.10 18.9\n",
       "13 15.71 21.7\n",
       "18 14.67 17.5\n",
       "24 19.88 14.5\n",
       "28 17.28 14.8"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "      crim                zn             indus        chas         nox        \n",
       " Min.   : 0.00906   Min.   : 0.000   Min.   : 0.740   0:101   Min.   :0.4000  \n",
       " 1st Qu.: 0.09535   1st Qu.: 0.000   1st Qu.: 5.945   1:  5   1st Qu.:0.4480  \n",
       " Median : 0.30770   Median : 0.000   Median :10.300           Median :0.5350  \n",
       " Mean   : 4.67018   Mean   : 8.929   Mean   :11.720           Mean   :0.5540  \n",
       " 3rd Qu.: 4.86247   3rd Qu.: 0.000   3rd Qu.:18.100           3rd Qu.:0.6128  \n",
       " Max.   :88.97620   Max.   :95.000   Max.   :27.740           Max.   :0.8710  \n",
       "       rm             age              dis             rad        \n",
       " Min.   :4.926   Min.   :  2.90   Min.   :1.202   Min.   : 1.000  \n",
       " 1st Qu.:5.910   1st Qu.: 37.98   1st Qu.:2.084   1st Qu.: 4.000  \n",
       " Median :6.231   Median : 76.35   Median :3.117   Median : 5.000  \n",
       " Mean   :6.330   Mean   : 66.01   Mean   :3.686   Mean   : 9.962  \n",
       " 3rd Qu.:6.562   3rd Qu.: 94.35   3rd Qu.:4.906   3rd Qu.:24.000  \n",
       " Max.   :8.398   Max.   :100.00   Max.   :9.188   Max.   :24.000  \n",
       "      tax           ptratio            b              lstat       \n",
       " Min.   :193.0   Min.   :13.00   Min.   :  0.32   Min.   : 2.960  \n",
       " 1st Qu.:287.5   1st Qu.:16.60   1st Qu.:368.61   1st Qu.: 6.758  \n",
       " Median :367.5   Median :18.40   Median :389.75   Median :11.690  \n",
       " Mean   :421.3   Mean   :18.23   Mean   :344.37   Mean   :12.806  \n",
       " 3rd Qu.:666.0   3rd Qu.:20.20   3rd Qu.:395.49   3rd Qu.:17.407  \n",
       " Max.   :711.0   Max.   :21.20   Max.   :396.90   Max.   :30.810  \n",
       "      medv      \n",
       " Min.   : 5.00  \n",
       " 1st Qu.:15.72  \n",
       " Median :21.45  \n",
       " Mean   :22.61  \n",
       " 3rd Qu.:26.57  \n",
       " Max.   :50.00  "
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Test data - quick summary\n",
    "dim(test)\n",
    "head(test)\n",
    "summary(test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training Different Regression Models\n",
    "\n",
    "We are now ready to train regression models using different algorithms in H2O. \n",
    "\n",
    "- First of all, we convert R data frames into H2O data frames. \n",
    "- Then, we define the names of features and target.\n",
    "- Finally, we train two different models:\n",
    "        - H2O Gradient Boosting Machines (CPU)\n",
    "        - H2O Distributed Random Forest (CPU)\n",
    "\n",
    "**Note 1**: Although the three algorithms used in this example are different, the core parameters are consistent (see below). This allows H2O users to get quick and easy access to different existing (and future) algorithms with a very shallow learning curve. **The core parameters** are:\n",
    "    - x = features\n",
    "    - y = target\n",
    "    - training_frame = h_train\n",
    "    \n",
    "**Note 2**: For model stacking, we need to generate holdout predictions from cross-validation. **The parameters required for model stacking** are:\n",
    "    - nfolds = 5\n",
    "    - fold_assignment = 'Modulo'\n",
    "    - keep_cross_validation_predictions = TRUE"
   ]
  },
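  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why a deterministic fold assignment matters for stacking, the sketch below reproduces a modulo-style assignment in base R. It assumes that `Modulo` places each row in the fold given by its row index modulo `nfolds`; `modulo_folds()` is a hypothetical helper written for illustration and is not part of `h2o` or `h2oEnsemble`.\n",
    "\n",
    "```r\n",
    "# Hypothetical helper (not an h2o function): assign each row to a fold\n",
    "# using its zero-based row index modulo the number of folds.\n",
    "modulo_folds <- function(n_rows, nfolds) {\n",
    "  (seq_len(n_rows) - 1) %% nfolds\n",
    "}\n",
    "\n",
    "# Same sizes as the training split used in this notebook\n",
    "fold_id <- modulo_folds(n_rows = 400, nfolds = 5)\n",
    "\n",
    "# The first rows cycle through folds 0, 1, 2, 3, 4, 0, 1, ...\n",
    "head(fold_id, 7)\n",
    "\n",
    "# Number of rows that land in each fold\n",
    "table(fold_id)\n",
    "```\n",
    "\n",
    "Because the assignment depends only on row order and `nfolds`, every model trained on the same frame with the same `nfolds` should end up with identical folds, so their holdout predictions line up row by row."
   ]
  },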
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Convert R data frames into H2O data frames\n",
    "h_train <- as.h2o(train)\n",
    "h_test <- as.h2o(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " [1] \"crim\"    \"zn\"      \"indus\"   \"chas\"    \"nox\"     \"rm\"      \"age\"    \n",
      " [8] \"dis\"     \"rad\"     \"tax\"     \"ptratio\" \"b\"       \"lstat\"  \n"
     ]
    }
   ],
   "source": [
    "# Regression - define features (x) and target (y)\n",
    "target <- \"medv\"\n",
    "features <- setdiff(colnames(train), target)\n",
    "print(features)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### H2O GBM model\n",
    "\n",
    "For more information, enter **`?h2o.gbm`** in R to look at the full list of parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train a H2O GBM model\n",
    "model_gbm <- h2o.gbm(x = features, y = target,\n",
    "                     training_frame = h_train,\n",
    "                     model_id = \"h2o_gbm\",\n",
    "                     learn_rate = 0.1,\n",
    "                     learn_rate_annealing = 0.99,\n",
    "                     sample_rate = 0.8,\n",
    "                     col_sample_rate = 0.8,\n",
    "                     nfolds = 5,\n",
    "                     fold_assignment = \"Modulo\",\n",
    "                     keep_cross_validation_predictions = TRUE,\n",
    "                     ntrees = 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### H2O DRF model\n",
    "\n",
    "For more information, enter **`?h2o.randomForest`** in R to look at the full list of parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train a H2O DRF model\n",
    "model_drf <- h2o.randomForest(x = features, y = target,\n",
    "                              training_frame = h_train,\n",
    "                              model_id = \"h2o_drf\",\n",
    "                              nfolds = 5,\n",
    "                              fold_assignment = \"Modulo\",\n",
    "                              keep_cross_validation_predictions = TRUE,\n",
    "                              ntrees = 100)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Stacking\n",
    "\n",
    "Now we have three different models, we are ready to carry out model stacking.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Create a list to include all the models for stacking\n",
    "models <- list(model_dw, model_gbm, model_drf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Define a metalearner (one of the H2O supervised machine learning algorithms)\n",
    "metalearner <- \"h2o.glm.wrapper\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] \"Metalearning\"\n"
     ]
    }
   ],
   "source": [
    "# Use h2o.stack() to carry out metalearning\n",
    "stack <- h2o.stack(models = models, \n",
    "                   response_frame = h_train$medv,\n",
    "                   metalearner = metalearner)"
   ]
  },
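  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, stacking fits a metalearner on the cross-validated holdout predictions of the base models. The sketch below illustrates that idea in base R on the built-in `mtcars` data. The two linear base learners and all variable names are assumptions chosen for illustration, and this is not a description of how `h2o.stack()` is implemented internally.\n",
    "\n",
    "```r\n",
    "# Minimal stacking sketch in base R (illustration only)\n",
    "data(mtcars)\n",
    "n <- nrow(mtcars)\n",
    "nfolds <- 5\n",
    "\n",
    "# Use the same modulo-style folds for every base learner\n",
    "fold_id <- (seq_len(n) - 1) %% nfolds\n",
    "\n",
    "# Two hypothetical base learners: linear models on different features\n",
    "base_formulas <- list(m1 = mpg ~ wt, m2 = mpg ~ hp + qsec)\n",
    "\n",
    "# Level-one data: the target plus one column of holdout predictions per learner\n",
    "level_one <- data.frame(mpg = mtcars$mpg, m1 = NA_real_, m2 = NA_real_)\n",
    "\n",
    "for (k in 0:(nfolds - 1)) {\n",
    "  holdout <- fold_id == k\n",
    "  for (name in names(base_formulas)) {\n",
    "    # Fit on the other folds, then predict the rows that were held out\n",
    "    fit <- lm(base_formulas[[name]], data = mtcars[!holdout, ])\n",
    "    level_one[holdout, name] <- predict(fit, newdata = mtcars[holdout, ])\n",
    "  }\n",
    "}\n",
    "\n",
    "# Metalearner: a linear model that learns how to weight the base learners\n",
    "metalearner_fit <- lm(mpg ~ m1 + m2, data = level_one)\n",
    "coef(metalearner_fit)\n",
    "```\n",
    "\n",
    "Each row's prediction in `level_one` comes from a model that never saw that row, which is why the base models in this notebook keep their cross-validation predictions."
   ]
  },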
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "Base learner performance, sorted by specified metric:\n",
       "        learner      MSE\n",
       "1 h2o_deepwater 8.377644\n",
       "2       h2o_gbm 8.106541\n",
       "3       h2o_drf 7.443517\n",
       "\n",
       "\n",
       "H2O Ensemble Performance on <newdata>:\n",
       "----------------\n",
       "Family: gaussian\n",
       "\n",
       "Ensemble performance (MSE): 5.80436983051916\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Finally, we evaluate the predictive performance on the ensemble as well as indiviudal models.\n",
    "h2o.ensemble_performance(stack, newdata = h_test)"
   ]
  },
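  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The metric reported above is the mean squared error (MSE), where lower is better. As a reminder of what that number measures, here is a small base R sketch; `mse()` is a hypothetical helper and the values are toy numbers, not results from the Boston Housing models.\n",
    "\n",
    "```r\n",
    "# Hypothetical helper: mean squared error between actual and predicted values\n",
    "mse <- function(actual, predicted) {\n",
    "  stopifnot(length(actual) == length(predicted))\n",
    "  mean((actual - predicted)^2)\n",
    "}\n",
    "\n",
    "# Toy values for illustration only\n",
    "actual    <- c(24, 22, 35)\n",
    "predicted <- c(25, 21, 33)\n",
    "\n",
    "# Errors are -1, 1 and 2, so the squared errors are 1, 1 and 4\n",
    "mse(actual, predicted)  # 2\n",
    "```\n",
    "\n",
    "Since the target is recorded in thousands of dollars, the MSE is expressed in squared thousands of dollars; taking the square root gives an error on the same scale as `medv`."
   ]
  },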
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Use the ensemble to make predictions\n",
    "yhat_test <- predict(stack, h_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "R",
   "language": "R",
   "name": "ir"
  },
  "language_info": {
   "codemirror_mode": "r",
   "file_extension": ".r",
   "mimetype": "text/x-r-source",
   "name": "R",
   "pygments_lexer": "r",
   "version": "3.2.3"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
